package com.jstarcraft.ai.jsat.classifiers.boosting;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import com.jstarcraft.ai.jsat.classifiers.CategoricalData;
import com.jstarcraft.ai.jsat.classifiers.CategoricalResults;
import com.jstarcraft.ai.jsat.classifiers.ClassificationDataSet;
import com.jstarcraft.ai.jsat.classifiers.Classifier;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.classifiers.calibration.BinaryScoreClassifier;
import com.jstarcraft.ai.jsat.parameters.Parameterized;

import it.unimi.dsi.fastutil.doubles.DoubleArrayList;

/**
 * Modest Ada Boost is a generalization of Discrete Ada Boost that attempts to
 * reduce the generalization error and avoid over-fitting. Empirically,
 * ModestBoost usually maintains a higher training-set error, and may take more
 * iterations to obtain the same test set error as other algorithms, but its
 * error does not increase as much after it reaches the minimum error - which
 * should make it easier to obtain higher accuracy. <br>
 * See: <br>
 * Vezhnevets, A.,&amp;Vezhnevets, V. (2005). <i>“Modest AdaBoost” – Teaching
 * AdaBoost to Generalize Better</i>. GraphiCon. Novosibirsk Akademgorodok,
 * Russia. Retrieved from
 * <a href="http://www.inf.ethz.ch/personal/vezhneva/Pubs/ModestAdaBoost.pdf">
 * here</a>
 * 
 * @author Edward Raff
 */
public class ModestAdaBoost implements Classifier, Parameterized, BinaryScoreClassifier {

    private static final long serialVersionUID = 8223388561185098909L;
    /** Prototype weak learner, cloned once for every boosting round. */
    private Classifier weakLearner;
    /** Maximum number of boosting rounds that may be performed. */
    private int maxIterations;
    /**
     * The list of weak hypothesis, one per completed boosting round
     */
    protected List<Classifier> hypoths;
    /**
     * The weights for each weak learner
     */
    protected DoubleArrayList hypWeights;
    /** Category information of the training set; {@code null} until trained. */
    protected CategoricalData predicting;

    /**
     * Creates a new ModestBoost learner
     * 
     * @param weakLearner   the weak learner to use
     * @param maxIterations the maximum number of boosting iterations
     */
    public ModestAdaBoost(Classifier weakLearner, int maxIterations) {
        setWeakLearner(weakLearner);
        setMaxIterations(maxIterations);
    }

    /**
     * Copy constructor
     * 
     * @param toClone the object to clone
     */
    protected ModestAdaBoost(ModestAdaBoost toClone) {
        this(toClone.weakLearner.clone(), toClone.maxIterations);
        // hypWeights is only non-null once the source object has been trained,
        // so the ensemble state is copied conditionally
        if (toClone.hypWeights != null) {
            this.hypWeights = new DoubleArrayList(toClone.hypWeights);
            this.hypoths = new ArrayList<Classifier>(toClone.maxIterations);
            for (Classifier weak : toClone.hypoths)
                this.hypoths.add(weak.clone());
            this.predicting = toClone.predicting.clone();
        }
    }

    /**
     * 
     * @return a list of the models that are in this ensemble.
     */
    public List<Classifier> getModels() {
        return Collections.unmodifiableList(hypoths);
    }

    /**
     * 
     * @return a list of the models weights that are in this ensemble.
     */
    public List<Double> getModelWeights() {
        return Collections.unmodifiableList(hypWeights);
    }

    /**
     * Returns the maximum number of iterations used
     * 
     * @return the maximum number of iterations used
     */
    public int getMaxIterations() {
        return maxIterations;
    }

    /**
     * Sets the maximal number of boosting iterations that may be performed
     * 
     * @param maxIterations the maximum number of iterations
     * @throws IllegalArgumentException if {@code maxIterations} is less than 1
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations < 1)
            throw new IllegalArgumentException("Iterations must be positive, not " + maxIterations);
        this.maxIterations = maxIterations;
    }

    /**
     * Returns the weak learner currently being used by this method.
     * 
     * @return the weak learner currently being used by this method.
     */
    public Classifier getWeakLearner() {
        return weakLearner;
    }

    /**
     * Sets the weak learner used during training.
     * 
     * @param weakLearner the weak learner to use
     * @throws IllegalArgumentException if the weak learner does not support
     *                                  weighted data, which boosting requires
     */
    public void setWeakLearner(Classifier weakLearner) {
        if (!weakLearner.supportsWeightedData())
            throw new IllegalArgumentException("WeakLearner must support weighted data to be boosted");
        this.weakLearner = weakLearner;
    }

    @Override
    public double getScore(DataPoint dp) {
        // Weighted vote of the ensemble; each weak hypothesis' P(class 1) is
        // mapped from [0, 1] to a [-1, 1] margin before weighting
        double score = 0;
        for (int i = 0; i < hypoths.size(); i++)
            score += (hypoths.get(i).classify(dp).getProb(1) * 2 - 1) * hypWeights.getDouble(i);
        return score;
    }

    @Override
    public CategoricalResults classify(DataPoint data) {
        if (predicting == null)
            throw new IllegalStateException("Classifier has not been trained yet");

        CategoricalResults cr = new CategoricalResults(predicting.getNumOfCategories());

        // Hard decision on the sign of the ensemble score: negative means
        // class 0, non-negative means class 1
        double score = getScore(data);
        if (score < 0)
            cr.setProb(0, 1.0);
        else
            cr.setProb(1, 1.0);
        return cr;
    }

    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        predicting = dataSet.getPredicting();
        hypWeights = new DoubleArrayList(maxIterations);
        hypoths = new ArrayList<Classifier>(maxIterations);
        final int N = dataSet.size();

        // D is the boosting distribution over the training points; D_inv is the
        // "inverted" distribution (proportional to 1 - D) that Modest AdaBoost
        // uses to penalize hypotheses that only work where the ensemble is
        // already confident
        double[] D_inv = new double[N];
        double[] D = new double[N];

        ClassificationDataSet cds = dataSet.shallowClone();
        Arrays.fill(D, 1.0 / N);
        for (int i = 0; i < N; i++)
            cds.setWeight(i, D[0]);// start from the uniform distribution, 1/N each
        double weightSum = 1;

        // margin in [-1, 1] of the current round's weak hypothesis per point
        double[] H_cur = new double[N];

        for (int t = 0; t < maxIterations; t++) {
            Classifier weak = weakLearner.clone();
            weak.train(cds, parallel);

            // Build the normalized inverted distribution from the current D
            double invSum = 0;
            for (int i = 0; i < N; i++)
                invSum += (D_inv[i] = 1 - D[i]);

            for (int i = 0; i < N; i++)
                D_inv[i] /= invSum;

            // p_* accumulate over positive examples, n_* over negative ones;
            // the *_d terms use D, the *_id terms use the inverted D_inv
            double p_d = 0, p_id = 0, n_d = 0, n_id = 0;

            for (int i = 0; i < N; i++) {
                H_cur[i] = (weak.classify(cds.getDataPoint(i)).getProb(1) * 2 - 1);
                double outPut = Math.signum(H_cur[i]);
                int c = cds.getDataPointCategory(i);
                if (c == 1)// positive example case
                {
                    p_d += outPut * D[i];
                    p_id += outPut * D_inv[i];
                } else {
                    n_d += outPut * D[i];
                    n_id += outPut * D_inv[i];
                }

            }

            // "Modest" coefficient: favors hypotheses that are correct where the
            // ensemble is weak (high D) but not already trusted (high D_inv)
            double alpha_m = p_d * (1 - p_id) - n_d * (1 - n_id);

            // Stop boosting on a degenerate round: inconsistent signs, a
            // vanishing margin, or a non-positive coefficient mean no progress
            if (Math.signum(alpha_m) != Math.signum(p_d - n_d) || Math.abs((p_d - n_d)) < 1e-6 || alpha_m <= 0)
                return;

            // Standard exponential-loss re-weighting of the training points
            weightSum = 0;
            for (int i = 0; i < N; i++) {
                double w_i = cds.getWeight(i);
                int y_i = cds.getDataPointCategory(i) * 2 - 1;// map {0,1} to {-1,+1}
                w_i *= Math.exp(-y_i * alpha_m * H_cur[i]);
                if (Double.isInfinite(w_i))
                    w_i = 1;// overflow: reset and let it grow back
                else if (w_i <= 0)
                    w_i = 1e-3 / N;// dont let it collapse to zero
                weightSum += w_i;
                cds.setWeight(i, w_i);
            }

            for (int i = 0; i < N; i++) {
                cds.setWeight(i, Math.max(cds.getWeight(i) / weightSum, 1e-10));
                // BUG FIX: keep D in sync with the updated boosting distribution.
                // Previously D was never written after its uniform initialization,
                // so D_inv and alpha_m were always computed against the uniform
                // distribution instead of the current one, contrary to the Modest
                // AdaBoost algorithm (Vezhnevets & Vezhnevets, 2005).
                D[i] = cds.getWeight(i);
            }

            hypWeights.add(alpha_m);
            hypoths.add(weak);
        }
    }

    /**
     * {@inheritDoc}
     * <p>
     * Always {@code false}: training replaces the weights of (a shallow clone of)
     * the input data set with the boosting distribution, so caller-supplied
     * weights are not honored.
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    @Override
    public ModestAdaBoost clone() {
        return new ModestAdaBoost(this);
    }
}
